Image Classification

For this problem, we will use the MNIST dataset: http://yann.lecun.com/exdb/mnist/


In [1]:
# Display the URL of the Spark web UI for this session.
# NOTE(review): `spark` is not created in this notebook; presumably the kernel
# provides a ready-made SparkSession -- confirm.
spark.sparkContext.uiWebUrl


Out[1]:
'http://172.17.1.79:4040'

In [4]:
import matplotlib.pyplot as plt
import pandas as pd

%matplotlib inline

In [10]:
# Load the MNIST training CSV. The file has no header row, so Spark assigns
# positional column names (_c0, _c1, ...); column types are inferred from the data.
df_training = spark.read.csv(
    "data/MNIST/mnist_train.csv",
    header=False,
    inferSchema=True,
)

In [71]:
# Number of rows in the training set.
df_training.count()


Out[71]:
60000

In [14]:
# Report how many columns were read, followed by the full list of column names.
column_names = df_training.columns
print("No of columns: ", len(column_names), column_names)


No of columns:  785 ['_c0', '_c1', '_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13', '_c14', '_c15', '_c16', '_c17', '_c18', '_c19', '_c20', '_c21', '_c22', '_c23', '_c24', '_c25', '_c26', '_c27', '_c28', '_c29', '_c30', '_c31', '_c32', '_c33', '_c34', '_c35', '_c36', '_c37', '_c38', '_c39', '_c40', '_c41', '_c42', '_c43', '_c44', '_c45', '_c46', '_c47', '_c48', '_c49', '_c50', '_c51', '_c52', '_c53', '_c54', '_c55', '_c56', '_c57', '_c58', '_c59', '_c60', '_c61', '_c62', '_c63', '_c64', '_c65', '_c66', '_c67', '_c68', '_c69', '_c70', '_c71', '_c72', '_c73', '_c74', '_c75', '_c76', '_c77', '_c78', '_c79', '_c80', '_c81', '_c82', '_c83', '_c84', '_c85', '_c86', '_c87', '_c88', '_c89', '_c90', '_c91', '_c92', '_c93', '_c94', '_c95', '_c96', '_c97', '_c98', '_c99', '_c100', '_c101', '_c102', '_c103', '_c104', '_c105', '_c106', '_c107', '_c108', '_c109', '_c110', '_c111', '_c112', '_c113', '_c114', '_c115', '_c116', '_c117', '_c118', '_c119', '_c120', '_c121', '_c122', '_c123', '_c124', '_c125', '_c126', '_c127', '_c128', '_c129', '_c130', '_c131', '_c132', '_c133', '_c134', '_c135', '_c136', '_c137', '_c138', '_c139', '_c140', '_c141', '_c142', '_c143', '_c144', '_c145', '_c146', '_c147', '_c148', '_c149', '_c150', '_c151', '_c152', '_c153', '_c154', '_c155', '_c156', '_c157', '_c158', '_c159', '_c160', '_c161', '_c162', '_c163', '_c164', '_c165', '_c166', '_c167', '_c168', '_c169', '_c170', '_c171', '_c172', '_c173', '_c174', '_c175', '_c176', '_c177', '_c178', '_c179', '_c180', '_c181', '_c182', '_c183', '_c184', '_c185', '_c186', '_c187', '_c188', '_c189', '_c190', '_c191', '_c192', '_c193', '_c194', '_c195', '_c196', '_c197', '_c198', '_c199', '_c200', '_c201', '_c202', '_c203', '_c204', '_c205', '_c206', '_c207', '_c208', '_c209', '_c210', '_c211', '_c212', '_c213', '_c214', '_c215', '_c216', '_c217', '_c218', '_c219', '_c220', '_c221', '_c222', '_c223', '_c224', '_c225', '_c226', '_c227', '_c228', '_c229', '_c230', '_c231', 
'_c232', '_c233', '_c234', '_c235', '_c236', '_c237', '_c238', '_c239', '_c240', '_c241', '_c242', '_c243', '_c244', '_c245', '_c246', '_c247', '_c248', '_c249', '_c250', '_c251', '_c252', '_c253', '_c254', '_c255', '_c256', '_c257', '_c258', '_c259', '_c260', '_c261', '_c262', '_c263', '_c264', '_c265', '_c266', '_c267', '_c268', '_c269', '_c270', '_c271', '_c272', '_c273', '_c274', '_c275', '_c276', '_c277', '_c278', '_c279', '_c280', '_c281', '_c282', '_c283', '_c284', '_c285', '_c286', '_c287', '_c288', '_c289', '_c290', '_c291', '_c292', '_c293', '_c294', '_c295', '_c296', '_c297', '_c298', '_c299', '_c300', '_c301', '_c302', '_c303', '_c304', '_c305', '_c306', '_c307', '_c308', '_c309', '_c310', '_c311', '_c312', '_c313', '_c314', '_c315', '_c316', '_c317', '_c318', '_c319', '_c320', '_c321', '_c322', '_c323', '_c324', '_c325', '_c326', '_c327', '_c328', '_c329', '_c330', '_c331', '_c332', '_c333', '_c334', '_c335', '_c336', '_c337', '_c338', '_c339', '_c340', '_c341', '_c342', '_c343', '_c344', '_c345', '_c346', '_c347', '_c348', '_c349', '_c350', '_c351', '_c352', '_c353', '_c354', '_c355', '_c356', '_c357', '_c358', '_c359', '_c360', '_c361', '_c362', '_c363', '_c364', '_c365', '_c366', '_c367', '_c368', '_c369', '_c370', '_c371', '_c372', '_c373', '_c374', '_c375', '_c376', '_c377', '_c378', '_c379', '_c380', '_c381', '_c382', '_c383', '_c384', '_c385', '_c386', '_c387', '_c388', '_c389', '_c390', '_c391', '_c392', '_c393', '_c394', '_c395', '_c396', '_c397', '_c398', '_c399', '_c400', '_c401', '_c402', '_c403', '_c404', '_c405', '_c406', '_c407', '_c408', '_c409', '_c410', '_c411', '_c412', '_c413', '_c414', '_c415', '_c416', '_c417', '_c418', '_c419', '_c420', '_c421', '_c422', '_c423', '_c424', '_c425', '_c426', '_c427', '_c428', '_c429', '_c430', '_c431', '_c432', '_c433', '_c434', '_c435', '_c436', '_c437', '_c438', '_c439', '_c440', '_c441', '_c442', '_c443', '_c444', '_c445', '_c446', '_c447', '_c448', '_c449', '_c450', '_c451', '_c452', '_c453', 
'_c454', '_c455', '_c456', '_c457', '_c458', '_c459', '_c460', '_c461', '_c462', '_c463', '_c464', '_c465', '_c466', '_c467', '_c468', '_c469', '_c470', '_c471', '_c472', '_c473', '_c474', '_c475', '_c476', '_c477', '_c478', '_c479', '_c480', '_c481', '_c482', '_c483', '_c484', '_c485', '_c486', '_c487', '_c488', '_c489', '_c490', '_c491', '_c492', '_c493', '_c494', '_c495', '_c496', '_c497', '_c498', '_c499', '_c500', '_c501', '_c502', '_c503', '_c504', '_c505', '_c506', '_c507', '_c508', '_c509', '_c510', '_c511', '_c512', '_c513', '_c514', '_c515', '_c516', '_c517', '_c518', '_c519', '_c520', '_c521', '_c522', '_c523', '_c524', '_c525', '_c526', '_c527', '_c528', '_c529', '_c530', '_c531', '_c532', '_c533', '_c534', '_c535', '_c536', '_c537', '_c538', '_c539', '_c540', '_c541', '_c542', '_c543', '_c544', '_c545', '_c546', '_c547', '_c548', '_c549', '_c550', '_c551', '_c552', '_c553', '_c554', '_c555', '_c556', '_c557', '_c558', '_c559', '_c560', '_c561', '_c562', '_c563', '_c564', '_c565', '_c566', '_c567', '_c568', '_c569', '_c570', '_c571', '_c572', '_c573', '_c574', '_c575', '_c576', '_c577', '_c578', '_c579', '_c580', '_c581', '_c582', '_c583', '_c584', '_c585', '_c586', '_c587', '_c588', '_c589', '_c590', '_c591', '_c592', '_c593', '_c594', '_c595', '_c596', '_c597', '_c598', '_c599', '_c600', '_c601', '_c602', '_c603', '_c604', '_c605', '_c606', '_c607', '_c608', '_c609', '_c610', '_c611', '_c612', '_c613', '_c614', '_c615', '_c616', '_c617', '_c618', '_c619', '_c620', '_c621', '_c622', '_c623', '_c624', '_c625', '_c626', '_c627', '_c628', '_c629', '_c630', '_c631', '_c632', '_c633', '_c634', '_c635', '_c636', '_c637', '_c638', '_c639', '_c640', '_c641', '_c642', '_c643', '_c644', '_c645', '_c646', '_c647', '_c648', '_c649', '_c650', '_c651', '_c652', '_c653', '_c654', '_c655', '_c656', '_c657', '_c658', '_c659', '_c660', '_c661', '_c662', '_c663', '_c664', '_c665', '_c666', '_c667', '_c668', '_c669', '_c670', '_c671', '_c672', '_c673', '_c674', '_c675', 
'_c676', '_c677', '_c678', '_c679', '_c680', '_c681', '_c682', '_c683', '_c684', '_c685', '_c686', '_c687', '_c688', '_c689', '_c690', '_c691', '_c692', '_c693', '_c694', '_c695', '_c696', '_c697', '_c698', '_c699', '_c700', '_c701', '_c702', '_c703', '_c704', '_c705', '_c706', '_c707', '_c708', '_c709', '_c710', '_c711', '_c712', '_c713', '_c714', '_c715', '_c716', '_c717', '_c718', '_c719', '_c720', '_c721', '_c722', '_c723', '_c724', '_c725', '_c726', '_c727', '_c728', '_c729', '_c730', '_c731', '_c732', '_c733', '_c734', '_c735', '_c736', '_c737', '_c738', '_c739', '_c740', '_c741', '_c742', '_c743', '_c744', '_c745', '_c746', '_c747', '_c748', '_c749', '_c750', '_c751', '_c752', '_c753', '_c754', '_c755', '_c756', '_c757', '_c758', '_c759', '_c760', '_c761', '_c762', '_c763', '_c764', '_c765', '_c766', '_c767', '_c768', '_c769', '_c770', '_c771', '_c772', '_c773', '_c774', '_c775', '_c776', '_c777', '_c778', '_c779', '_c780', '_c781', '_c782', '_c783', '_c784']

In [18]:
# Pixel columns are _c1 .. _c784; _c0 is left out because it is used as the label.
# NOTE: the variable name is misspelled ("culumns") but is referenced by a
# later cell, so it is kept unchanged here.
feature_culumns = ["_c%d" % idx for idx in range(1, 785)]
print(feature_culumns)


['_c1', '_c2', '_c3', '_c4', '_c5', '_c6', '_c7', '_c8', '_c9', '_c10', '_c11', '_c12', '_c13', '_c14', '_c15', '_c16', '_c17', '_c18', '_c19', '_c20', '_c21', '_c22', '_c23', '_c24', '_c25', '_c26', '_c27', '_c28', '_c29', '_c30', '_c31', '_c32', '_c33', '_c34', '_c35', '_c36', '_c37', '_c38', '_c39', '_c40', '_c41', '_c42', '_c43', '_c44', '_c45', '_c46', '_c47', '_c48', '_c49', '_c50', '_c51', '_c52', '_c53', '_c54', '_c55', '_c56', '_c57', '_c58', '_c59', '_c60', '_c61', '_c62', '_c63', '_c64', '_c65', '_c66', '_c67', '_c68', '_c69', '_c70', '_c71', '_c72', '_c73', '_c74', '_c75', '_c76', '_c77', '_c78', '_c79', '_c80', '_c81', '_c82', '_c83', '_c84', '_c85', '_c86', '_c87', '_c88', '_c89', '_c90', '_c91', '_c92', '_c93', '_c94', '_c95', '_c96', '_c97', '_c98', '_c99', '_c100', '_c101', '_c102', '_c103', '_c104', '_c105', '_c106', '_c107', '_c108', '_c109', '_c110', '_c111', '_c112', '_c113', '_c114', '_c115', '_c116', '_c117', '_c118', '_c119', '_c120', '_c121', '_c122', '_c123', '_c124', '_c125', '_c126', '_c127', '_c128', '_c129', '_c130', '_c131', '_c132', '_c133', '_c134', '_c135', '_c136', '_c137', '_c138', '_c139', '_c140', '_c141', '_c142', '_c143', '_c144', '_c145', '_c146', '_c147', '_c148', '_c149', '_c150', '_c151', '_c152', '_c153', '_c154', '_c155', '_c156', '_c157', '_c158', '_c159', '_c160', '_c161', '_c162', '_c163', '_c164', '_c165', '_c166', '_c167', '_c168', '_c169', '_c170', '_c171', '_c172', '_c173', '_c174', '_c175', '_c176', '_c177', '_c178', '_c179', '_c180', '_c181', '_c182', '_c183', '_c184', '_c185', '_c186', '_c187', '_c188', '_c189', '_c190', '_c191', '_c192', '_c193', '_c194', '_c195', '_c196', '_c197', '_c198', '_c199', '_c200', '_c201', '_c202', '_c203', '_c204', '_c205', '_c206', '_c207', '_c208', '_c209', '_c210', '_c211', '_c212', '_c213', '_c214', '_c215', '_c216', '_c217', '_c218', '_c219', '_c220', '_c221', '_c222', '_c223', '_c224', '_c225', '_c226', '_c227', '_c228', '_c229', '_c230', '_c231', '_c232', '_c233', '_c234', 
'_c235', '_c236', '_c237', '_c238', '_c239', '_c240', '_c241', '_c242', '_c243', '_c244', '_c245', '_c246', '_c247', '_c248', '_c249', '_c250', '_c251', '_c252', '_c253', '_c254', '_c255', '_c256', '_c257', '_c258', '_c259', '_c260', '_c261', '_c262', '_c263', '_c264', '_c265', '_c266', '_c267', '_c268', '_c269', '_c270', '_c271', '_c272', '_c273', '_c274', '_c275', '_c276', '_c277', '_c278', '_c279', '_c280', '_c281', '_c282', '_c283', '_c284', '_c285', '_c286', '_c287', '_c288', '_c289', '_c290', '_c291', '_c292', '_c293', '_c294', '_c295', '_c296', '_c297', '_c298', '_c299', '_c300', '_c301', '_c302', '_c303', '_c304', '_c305', '_c306', '_c307', '_c308', '_c309', '_c310', '_c311', '_c312', '_c313', '_c314', '_c315', '_c316', '_c317', '_c318', '_c319', '_c320', '_c321', '_c322', '_c323', '_c324', '_c325', '_c326', '_c327', '_c328', '_c329', '_c330', '_c331', '_c332', '_c333', '_c334', '_c335', '_c336', '_c337', '_c338', '_c339', '_c340', '_c341', '_c342', '_c343', '_c344', '_c345', '_c346', '_c347', '_c348', '_c349', '_c350', '_c351', '_c352', '_c353', '_c354', '_c355', '_c356', '_c357', '_c358', '_c359', '_c360', '_c361', '_c362', '_c363', '_c364', '_c365', '_c366', '_c367', '_c368', '_c369', '_c370', '_c371', '_c372', '_c373', '_c374', '_c375', '_c376', '_c377', '_c378', '_c379', '_c380', '_c381', '_c382', '_c383', '_c384', '_c385', '_c386', '_c387', '_c388', '_c389', '_c390', '_c391', '_c392', '_c393', '_c394', '_c395', '_c396', '_c397', '_c398', '_c399', '_c400', '_c401', '_c402', '_c403', '_c404', '_c405', '_c406', '_c407', '_c408', '_c409', '_c410', '_c411', '_c412', '_c413', '_c414', '_c415', '_c416', '_c417', '_c418', '_c419', '_c420', '_c421', '_c422', '_c423', '_c424', '_c425', '_c426', '_c427', '_c428', '_c429', '_c430', '_c431', '_c432', '_c433', '_c434', '_c435', '_c436', '_c437', '_c438', '_c439', '_c440', '_c441', '_c442', '_c443', '_c444', '_c445', '_c446', '_c447', '_c448', '_c449', '_c450', '_c451', '_c452', '_c453', '_c454', '_c455', '_c456', 
'_c457', '_c458', '_c459', '_c460', '_c461', '_c462', '_c463', '_c464', '_c465', '_c466', '_c467', '_c468', '_c469', '_c470', '_c471', '_c472', '_c473', '_c474', '_c475', '_c476', '_c477', '_c478', '_c479', '_c480', '_c481', '_c482', '_c483', '_c484', '_c485', '_c486', '_c487', '_c488', '_c489', '_c490', '_c491', '_c492', '_c493', '_c494', '_c495', '_c496', '_c497', '_c498', '_c499', '_c500', '_c501', '_c502', '_c503', '_c504', '_c505', '_c506', '_c507', '_c508', '_c509', '_c510', '_c511', '_c512', '_c513', '_c514', '_c515', '_c516', '_c517', '_c518', '_c519', '_c520', '_c521', '_c522', '_c523', '_c524', '_c525', '_c526', '_c527', '_c528', '_c529', '_c530', '_c531', '_c532', '_c533', '_c534', '_c535', '_c536', '_c537', '_c538', '_c539', '_c540', '_c541', '_c542', '_c543', '_c544', '_c545', '_c546', '_c547', '_c548', '_c549', '_c550', '_c551', '_c552', '_c553', '_c554', '_c555', '_c556', '_c557', '_c558', '_c559', '_c560', '_c561', '_c562', '_c563', '_c564', '_c565', '_c566', '_c567', '_c568', '_c569', '_c570', '_c571', '_c572', '_c573', '_c574', '_c575', '_c576', '_c577', '_c578', '_c579', '_c580', '_c581', '_c582', '_c583', '_c584', '_c585', '_c586', '_c587', '_c588', '_c589', '_c590', '_c591', '_c592', '_c593', '_c594', '_c595', '_c596', '_c597', '_c598', '_c599', '_c600', '_c601', '_c602', '_c603', '_c604', '_c605', '_c606', '_c607', '_c608', '_c609', '_c610', '_c611', '_c612', '_c613', '_c614', '_c615', '_c616', '_c617', '_c618', '_c619', '_c620', '_c621', '_c622', '_c623', '_c624', '_c625', '_c626', '_c627', '_c628', '_c629', '_c630', '_c631', '_c632', '_c633', '_c634', '_c635', '_c636', '_c637', '_c638', '_c639', '_c640', '_c641', '_c642', '_c643', '_c644', '_c645', '_c646', '_c647', '_c648', '_c649', '_c650', '_c651', '_c652', '_c653', '_c654', '_c655', '_c656', '_c657', '_c658', '_c659', '_c660', '_c661', '_c662', '_c663', '_c664', '_c665', '_c666', '_c667', '_c668', '_c669', '_c670', '_c671', '_c672', '_c673', '_c674', '_c675', '_c676', '_c677', '_c678', 
'_c679', '_c680', '_c681', '_c682', '_c683', '_c684', '_c685', '_c686', '_c687', '_c688', '_c689', '_c690', '_c691', '_c692', '_c693', '_c694', '_c695', '_c696', '_c697', '_c698', '_c699', '_c700', '_c701', '_c702', '_c703', '_c704', '_c705', '_c706', '_c707', '_c708', '_c709', '_c710', '_c711', '_c712', '_c713', '_c714', '_c715', '_c716', '_c717', '_c718', '_c719', '_c720', '_c721', '_c722', '_c723', '_c724', '_c725', '_c726', '_c727', '_c728', '_c729', '_c730', '_c731', '_c732', '_c733', '_c734', '_c735', '_c736', '_c737', '_c738', '_c739', '_c740', '_c741', '_c742', '_c743', '_c744', '_c745', '_c746', '_c747', '_c748', '_c749', '_c750', '_c751', '_c752', '_c753', '_c754', '_c755', '_c756', '_c757', '_c758', '_c759', '_c760', '_c761', '_c762', '_c763', '_c764', '_c765', '_c766', '_c767', '_c768', '_c769', '_c770', '_c771', '_c772', '_c773', '_c774', '_c775', '_c776', '_c777', '_c778', '_c779', '_c780', '_c781', '_c782', '_c783', '_c784']

In [19]:
from pyspark.ml.feature import VectorAssembler

In [20]:
# Pack the pixel columns into a single ML vector column named "features".
vectorizer = VectorAssembler(inputCols=feature_culumns, outputCol="features")

# Keep only the label (originally _c0) and the assembled vector, and cache the
# result since later cells read it repeatedly.
assembled_training = vectorizer.transform(df_training)
training = (assembled_training
            .withColumnRenamed("_c0", "label")
            .select("label", "features")
            .cache())
training.show()


+-----+--------------------+
|label|            features|
+-----+--------------------+
|    5|(784,[152,153,154...|
|    0|(784,[127,128,129...|
|    4|(784,[160,161,162...|
|    1|(784,[158,159,160...|
|    9|(784,[208,209,210...|
|    2|(784,[155,156,157...|
|    1|(784,[124,125,126...|
|    3|(784,[151,152,153...|
|    1|(784,[152,153,154...|
|    4|(784,[134,135,161...|
|    3|(784,[123,124,125...|
|    5|(784,[216,217,218...|
|    3|(784,[143,144,145...|
|    6|(784,[72,73,74,99...|
|    1|(784,[151,152,153...|
|    7|(784,[211,212,213...|
|    2|(784,[151,152,153...|
|    8|(784,[159,160,161...|
|    6|(784,[100,101,102...|
|    9|(784,[209,210,211...|
+-----+--------------------+
only showing top 20 rows


In [26]:
# Fetch the first training row and convert its feature vector to an array,
# then show the resulting type.
first_row = training.first()
a = first_row.features.toArray()
type(a)


Out[26]:
numpy.ndarray

In [29]:
# Render the 784 pixel values of the first row as a 28x28 image.
plt.imshow(a.reshape(28, 28), cmap="Greys")


Out[29]:
<matplotlib.image.AxesImage at 0x11b842080>

In [40]:
# Draw up to 25 sampled training digits in a 5x5 grid, each titled with its
# true label. The sample is taken without replacement, at a 1% fraction, with
# seed 1 so the same rows are picked on every run.
images = training.sample(False, 0.01, 1).take(25)

# Bind the axes array to a real name: assigning to `_` in IPython shadows the
# last-output variable for the rest of the session.
fig, axes = plt.subplots(5, 5, figsize = (10, 10))

# zip() stops at the shorter sequence, so a sample with fewer than 25 rows
# leaves the remaining panels empty instead of raising IndexError.
for ax, row in zip(axes.flat, images):
    ax.imshow(row.features.toArray().reshape(28, 28), cmap = "Greys")
    ax.set_title("True: " + str(row.label))

plt.tight_layout()



In [74]:
# Count how many training rows there are for each label.
counts = training.groupBy("label").count()

In [76]:
# Bring the per-label counts to the driver as a pandas DataFrame and plot them
# as a bar chart ordered by label.
# toPandas() replaces the manual rdd.map(...).collect() round trip through a
# list of dicts, and makes `counts_df` an actual DataFrame as its name suggests.
counts_df = counts.toPandas()
counts_df.set_index("label").sort_index().plot.bar()


Out[76]:
<matplotlib.axes._subplots.AxesSubplot at 0x120932cf8>

In [55]:
# Load the MNIST test CSV (no header row; column types inferred from the data).
df_testing = spark.read.csv(
    "data/MNIST/mnist_test.csv",
    header=False,
    inferSchema=True,
)

# Assemble the pixel columns with the existing vectorizer, keep only the label
# (originally _c0) and the feature vector, and cache the result.
assembled_testing = vectorizer.transform(df_testing)
testing = (assembled_testing
           .withColumnRenamed("_c0", "label")
           .select("label", "features")
           .cache())

In [56]:
from pyspark.ml.classification import LogisticRegression

In [57]:
# Logistic regression estimator with elastic-net regularization.
# regParam is the overall penalty strength; elasticNetParam is the L1/L2 mixing
# ratio (0 = pure L2, 1 = pure L1). maxIter caps the number of optimizer iterations.
lr_params = {
    "featuresCol": "features",
    "labelCol": "label",
    "regParam": 0.1,
    "elasticNetParam": 0.1,
    "maxIter": 10000,
}
lr = LogisticRegression(**lr_params)

In [58]:
# Fit the logistic regression estimator on the training set.
lr_model = lr.fit(training)

In [65]:
from pyspark.sql.functions import *

In [67]:
# Score the test set and add a boolean "matched" column that is true when the
# predicted class equals the true label.
test_pred = (lr_model
             .transform(testing)
             .withColumn("matched", expr("label == prediction")))
test_pred.show()


+-----+--------------------+--------------------+--------------------+----------+-------+
|label|            features|       rawPrediction|         probability|prediction|matched|
+-----+--------------------+--------------------+--------------------+----------+-------+
|    7|(784,[202,203,204...|[0.08070480165374...|[0.01163236392094...|       7.0|   true|
|    2|(784,[94,95,96,97...|[1.25425406358765...|[0.02467764471716...|       2.0|   true|
|    1|(784,[128,129,130...|[-1.2276524471687...|[0.00751263994471...|       1.0|   true|
|    0|(784,[124,125,126...|[3.71476062530585...|[0.85536145424016...|       0.0|   true|
|    4|(784,[150,151,159...|[-0.2039270192761...|[0.04549113250992...|       4.0|   true|
|    1|(784,[156,157,158...|[-1.8454996939466...|[0.00278444647289...|       1.0|   true|
|    4|(784,[149,150,151...|[-1.7871799792661...|[0.00965258005973...|       4.0|   true|
|    9|(784,[179,180,181...|[-2.3331144616742...|[0.00711927716793...|       9.0|   true|
|    5|(784,[129,130,131...|[0.25089606477712...|[0.06160565780450...|       5.0|   true|
|    9|(784,[209,210,211...|[-0.7976271976762...|[0.02088880010781...|       9.0|   true|
|    0|(784,[123,124,125...|[3.87665582702822...|[0.80110751027447...|       0.0|   true|
|    6|(784,[94,95,96,97...|[1.89086135619814...|[0.28924960656629...|       0.0|  false|
|    9|(784,[208,209,210...|[-1.4656858164681...|[0.00902667735258...|       9.0|   true|
|    0|(784,[152,153,154...|[3.52162978716862...|[0.84049195109150...|       0.0|   true|
|    1|(784,[125,126,127...|[-2.2951507253657...|[0.00159973133619...|       1.0|   true|
|    5|(784,[124,125,126...|[0.26443714870398...|[0.07367353100470...|       5.0|   true|
|    9|(784,[179,180,181...|[-0.6255916188595...|[0.02550361387932...|       9.0|   true|
|    7|(784,[200,201,202...|[0.70990611838535...|[0.02075121690891...|       7.0|   true|
|    3|(784,[118,119,120...|[-1.0216859510239...|[0.01042092765356...|       3.0|   true|
|    4|(784,[158,159,185...|[-1.2147078347317...|[0.01415523283792...|       4.0|   true|
+-----+--------------------+--------------------+--------------------+----------+-------+
only showing top 20 rows


In [61]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [62]:
# Evaluator that compares the "prediction" column against "label" and reports
# the accuracy metric.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)

In [63]:
# Accuracy of the current predictions on the test set.
evaluator.evaluate(test_pred)


Out[63]:
0.8729

In [70]:
# Per-label accuracy: turn the boolean "matched" flag into 0/1 and average it
# within each label, listed in label order.
per_label_accuracy = (test_pred
                      .withColumn("matched", expr("cast(matched as int)"))
                      .groupBy("label")
                      .agg({"matched": "avg"})
                      .orderBy("label"))
per_label_accuracy.show()


+-----+------------------+
|label|      avg(matched)|
+-----+------------------+
|    0|0.9622448979591837|
|    1|0.9718061674008811|
|    2|0.8275193798449613|
|    3|0.8772277227722772|
|    4|0.8940936863543788|
|    5|0.7600896860986547|
|    6|0.9050104384133612|
|    7|0.8735408560311284|
|    8|0.7895277207392197|
|    9|0.8453914767096135|
+-----+------------------+

Classifying MNIST using Neural Networks


In [15]:
from pyspark.ml.classification import MultilayerPerceptronClassifier

In [16]:
# Layer sizes: 784 inputs (one per pixel), hidden layers of 100 and 20 units,
# and 10 outputs (one per digit label).
layers = [784, 100, 20, 10]

# A fixed seed keeps weight initialisation reproducible between runs.
perceptron = MultilayerPerceptronClassifier(maxIter=1000, layers=layers, blockSize=128, seed=1234)

# The model is fitted (and assigned to `perceptron_model`) in the timed cell
# further down; fitting here as well would train the identical network twice.

In [ ]:
from time import time

In [ ]:
# Time the full train -> predict -> evaluate cycle for the perceptron.
start_time = time()

perceptron_model = perceptron.fit(training)
test_pred = perceptron_model.transform(testing)
accuracy = evaluator.evaluate(test_pred)
print("Accuracy:", accuracy)

# %d truncates the elapsed wall-clock time to whole seconds.
elapsed = time() - start_time
print("Time taken: %d" % elapsed)

In [ ]: